package ipratelimit

import (
	"runtime/debug"
	"sync"
	"time"

	"gitee.com/clearluo/gotools/util/ipratelimit/internal/typ"

	"gitee.com/clearluo/gotools/util"

	"gitee.com/clearluo/gotools/zaplog"

	"golang.org/x/time/rate"
)

// rateLimiterT holds the limiter state kept for a single client IP.
type rateLimiterT struct {
	// ipInt is the IP this entry belongs to.
	ipInt typ.IPInt

	// ipLimit is the per-IP request limiter.
	ipLimit *rate.Limiter

	// endUnix is the Unix time (seconds) after which clear may evict this
	// entry; update pushes it forward on each access.
	endUnix int64

	// blackLimit is the limiter consulted for the "black room" check
	// (presumably a temporary ban — confirm against callers).
	blackLimit *rate.Limiter
}

// update renews the entry so it stays eligible for another hour from now.
func (r *rateLimiterT) update() {
	expiry := time.Now().Add(60 * time.Minute)
	r.endUnix = expiry.Unix()
}

// ipAllow takes one token from the per-IP limiter and reports whether it was available.
func (r *rateLimiterT) ipAllow() bool {
	allowed := r.ipLimit.Allow()
	return allowed
}

// blackAllow takes one token from the black-room limiter and reports whether it was available.
func (r *rateLimiterT) blackAllow() bool {
	allowed := r.blackLimit.Allow()
	return allowed
}

// IPRateLimiterT keeps one rateLimiterT per client IP and applies two
// token-bucket limits to each: a per-IP request limit and a separate
// "black room" limit (presumably used to decide temporary bans).
//
// The embedded RWMutex is taken by the methods that read or modify ips.
type IPRateLimiterT struct {
	ips map[typ.IPInt]*rateLimiterT

	sync.RWMutex

	ipRateLimit    rate.Limit // refill rate (tokens per second) of each per-IP limiter
	ipBucketLen    int        // bucket size (burst) of each per-IP limiter
	blackRateLimit rate.Limit // refill rate (tokens per second) of each black-room limiter
	blackBucketLen int        // bucket size (burst) of each black-room limiter
}

// NewIPRateLimiter creates an IPRateLimiterT and starts its background
// cleanup goroutine.
//
//   - rateSpeed: tokens per second added to each IP's request bucket.
//   - bucketLen: capacity (burst) of each IP's request bucket.
//   - blackTimes, intervalSecond: each IP's black-room limiter allows
//     blackTimes events per intervalSecond seconds, with a burst of blackTimes.
//
// If intervalSecond is not positive, the black-room rate is rate.Inf
// (unlimited). Dividing by zero there would instead produce +Inf or NaN,
// and the limiter does not handle those values correctly.
//
// NOTE(review): the cleanup goroutine has no stop mechanism, so every call
// starts a goroutine that lives for the rest of the process.
func NewIPRateLimiter(rateSpeed int, bucketLen int, blackTimes int, intervalSecond int) *IPRateLimiterT {
	blackRate := rate.Inf
	if intervalSecond > 0 {
		blackRate = rate.Limit(float64(blackTimes) / float64(intervalSecond))
	}
	o := &IPRateLimiterT{
		ips:            make(map[typ.IPInt]*rateLimiterT),
		ipRateLimit:    rate.Limit(rateSpeed),
		ipBucketLen:    bucketLen,
		blackRateLimit: blackRate,
		blackBucketLen: blackTimes,
	}
	o.tickerClean()
	return o
}
// tickerClean starts a background goroutine that calls i.clear once per hour.
//
// If a cleanup panics, the panic and stack are logged via zaplog.Warn. The
// goroutine then sleeps one minute and starts a fresh ticker loop.
//
// NOTE(review): the goroutine never exits; there is no way to stop it.
func (i *IPRateLimiterT) tickerClean() {
	go func() {
		for {
			func() {
				ticker := time.NewTicker(time.Hour)
				// Stop the ticker on every exit path, including a recovered
				// panic, so a restart does not leave the old ticker running.
				defer ticker.Stop()
				defer func() {
					if err := recover(); err != nil {
						zaplog.Warn("IPRateLimiterT.tickerClean panic:%v,stack:%v", err, string(debug.Stack()))
					}
				}()
				for range ticker.C {
					i.clear()
				}
			}()
			time.Sleep(time.Minute)
		}
	}()
}
// SetParam updates the limiter configuration.
//
// A non-positive rateSpeed or bucketLen leaves that setting unchanged. The
// black-room pair is updated only when both blackTimes and intervalSecond
// are positive.
//
// The new values apply only to IPs that get a limiter after this call.
// Limiters already stored in i.ips keep the rates they were created with.
// The write lock is required because addIP reads these fields while the
// lock is held.
func (i *IPRateLimiterT) SetParam(rateSpeed int, bucketLen int, blackTimes int, intervalSecond int) {
	i.Lock()
	defer i.Unlock()
	if rateSpeed > 0 {
		i.ipRateLimit = rate.Limit(rateSpeed)
	}
	if bucketLen > 0 {
		i.ipBucketLen = bucketLen
	}
	if blackTimes > 0 && intervalSecond > 0 {
		i.blackRateLimit = rate.Limit(float64(blackTimes) / float64(intervalSecond))
		i.blackBucketLen = blackTimes
	}
}

// addIP creates a limiter pair for ipInt from the current configuration,
// stores it in i.ips and returns it. The new entry gets the same expiry
// window that update applies when an entry is renewed.
//
// addIP does not lock; the caller must already hold i's write lock.
func (i *IPRateLimiterT) addIP(ipInt typ.IPInt) *rateLimiterT {
	limiter := &rateLimiterT{
		ipInt:      ipInt,
		ipLimit:    rate.NewLimiter(i.ipRateLimit, i.ipBucketLen),
		blackLimit: rate.NewLimiter(i.blackRateLimit, i.blackBucketLen),
	}
	// Use update so new entries and renewed entries share one expiry rule.
	limiter.update()
	i.ips[ipInt] = limiter
	return limiter
}
// Len returns the number of IPs that currently have a limiter entry.
// It reads the map under the read lock.
func (i *IPRateLimiterT) Len() int {
	i.RLock()
	count := len(i.ips)
	i.RUnlock()
	return count
}

// clear evicts every entry whose endUnix is strictly before the current Unix
// time. It holds the write lock for the whole sweep. The call is wrapped in
// util.Profiling (presumably for timing; confirm).
func (i *IPRateLimiterT) clear() {
	defer util.Profiling("IPRateLimiterT.clear()")()
	i.Lock()
	defer i.Unlock()
	cutoff := time.Now().Unix()
	for key, entry := range i.ips {
		if entry.endUnix >= cutoff {
			continue // still within its expiry window
		}
		delete(i.ips, key)
	}
}
// IPAllow reports whether a request from ipInt passes its per-IP limiter,
// consuming a token when it does.
//
// An IP with no entry gets a new one via addIP. An existing entry is renewed
// (expiry set to one hour from now) before the check.
func (i *IPRateLimiterT) IPAllow(ipInt typ.IPInt) bool {
	i.Lock()
	defer i.Unlock()
	if limiter, ok := i.ips[ipInt]; ok {
		limiter.update() // renew
		return limiter.ipAllow()
	}
	return i.addIP(ipInt).ipAllow()
}

// BlackAllow reports whether ipInt passes its black-room limiter, consuming a
// token when it does.
//
// An IP with no entry gets a new one via addIP. An existing entry is renewed
// (expiry set to one hour from now) before the check.
//
// NOTE(review): presumably a false result means the caller should put the IP
// in the black room (temporary ban); confirm against callers.
func (i *IPRateLimiterT) BlackAllow(ipInt typ.IPInt) bool {
	i.Lock()
	defer i.Unlock()
	if limiter, ok := i.ips[ipInt]; ok {
		limiter.update() // renew
		return limiter.blackAllow()
	}
	return i.addIP(ipInt).blackAllow()
}
